为什么需要bucket

bucket就是一种编码思想,bucket的存在是为了减小计算量,从而可以减少模型的训练时间。当然,使用dynamic_rnn或rnn这两个接口也可以减少运算时间。bucket是用在使用在cell(input,state)这种古老的方法上的。

  • 每一个bucket都是一个固定的computation graph;
  • 其次,每一个sequence的pad都不是很多,对于计算资源的浪费很小
  • 再次,这样的实现很简单,就是一个给长度聚类,对于framework的要求很低
  1. 对train set:要对sequence的长度聚类,确保如何分配bucket。
  2. 数据依旧要填充到最大长度
  3. 对每个bucket都要建立一个模型,但是模型都是共享变量的
  4. 对每个模型都要都要计算loss,保存到list中
  5. 训练的时候,最小化对应bucket的loss

源码

https://github.com/tensorflow/tensorflow/blob/27711108b5fce2e1692f9440631a183b3808fa01/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py#L1118

Update

最新版本的 tensorflow 不需要使用 bucketing了,直接用 dynamic rnn 就好,它会根据每个batch自动计算最好的输出,不过要更定每个
example的 sequence length。
当然,现在有人可以做到自动计算 sequence length 了,可参考 tensorlayer.layers
这个方法也用在google 最新开源的 image captioning 例子里了。

dynamic rnn是如何解决效率问题的?

有了dynamic rnn, buket也仍然有用,因为后续logits loss计算的时候batch length小仍然会减少计算量。“One reason is that seq2seq was created before dynamic rnn was available. The other is that, even with dynamic rnn, it still helps for speed if your batches are organized by bucket”

dynamic rnn并不会提升效率.

The parameter sequence_length is optional and is used to copy-through state and zero-out outputs when past a batch element’s sequence length. So it’s more for correctness than performance.

参考